import yaml
import logging
# Send this script's log records from the root logger to a stream handler.
# The root logger defaults to WARNING, so its own level must be lowered too;
# otherwise INFO records are dropped before they ever reach the handler.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.addHandler(handler)

import os
import torch
from datasets import DatasetDict
import argparse
from main.wmpatch import GTWatermark, GTWatermarkMulti
from main.dataset import get_dataset_base


# Prefer reproducibility over speed: keep cuDNN benchmark mode off (turning it
# on can be faster but may not be deterministic) and require deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

def main(args):
    """Attach per-example watermark tensors to the train/test splits and save them.

    Reads the YAML config at ``args.cfg_path`` (its ``w_channel``, ``w_radius``
    and ``w_seed`` entries are used) and builds a ``GTWatermark`` pipeline on
    the CUDA device. Every example of the train and test splits of
    ``args.dataset`` then gets three new columns: ``gt_patch_real``,
    ``gt_patch_imag`` and ``watermarking_mask``. The combined ``DatasetDict``
    is written to ``<args.dataset_path>/<args.dataset>_water``.

    Args:
        args: Parsed CLI namespace; ``cfg_path``, ``batch_size``, ``dataset``
            and ``dataset_path`` are read.
    """
    logger.info('===== Load Config =====')
    with open(args.cfg_path, 'r', encoding='utf-8') as file:
        cfgs = yaml.safe_load(file)
    logger.info(cfgs)

    # NOTE(review): the device is hard-coded to CUDA; CPU-only hosts are not supported here.
    device = torch.device('cuda')

    # A single pipeline serves both splits. Its generator is seeded from the
    # config's ``w_seed`` (presumably the source of watermark randomness inside
    # GTWatermark -- confirm).
    wm_pipe = GTWatermark(
        device,
        shape=(args.batch_size, 4, 64, 64),
        w_channel=cfgs['w_channel'],
        w_radius=cfgs['w_radius'],
        generator=torch.Generator(device).manual_seed(cfgs['w_seed']),
    )

    def add_tensors(example):
        """Return new watermark columns for ``example``.

        The example's own fields are not read. Each call runs
        ``wm_pipe.generate_watermark()`` and then reads the pipeline's
        ``gt_patch`` and ``watermarking_mask``.
        """
        wm_pipe.generate_watermark()
        # squeeze() drops all size-1 dims (presumably the batch dim when batch_size == 1).
        real = wm_pipe.gt_patch.real.squeeze()
        imag = wm_pipe.gt_patch.imag.squeeze()
        # NOTE(review): the mask is stored without squeeze(), unlike gt_patch -- confirm intended.
        return {'gt_patch_real': real, 'gt_patch_imag': imag, 'watermarking_mask': wm_pipe.watermarking_mask}

    # Train is processed before test. Both splits share wm_pipe, so this order
    # decides which generated watermarks each split receives.
    watermarked = {}
    for split, is_train in (('train', True), ('test', False)):
        dataset = get_dataset_base(args.dataset_path, args.dataset, is_train=is_train)
        watermarked[split] = dataset.map(add_tensors, batched=False)

    joint_train_test_data = DatasetDict(watermarked)

    save_path = os.path.join(args.dataset_path, args.dataset + '_water')
    joint_train_test_data.save_to_disk(save_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='diffusion watermark')
    parser.add_argument('--dataset', default='diffusiondb', choices=['coco', 'diffusiondb', 'wikiart'],
                        help='dataset to watermark')
    parser.add_argument('--dataset_path', default='/localhome/data/datasets/watermarking',
                        help='root directory of the datasets; the output is saved under it as <dataset>_water')
    parser.add_argument('--seed', default=0, type=int,
                        help='global torch RNG seed')
    parser.add_argument('--cfg_path', default='./example/config/config.yaml',
                        help='YAML config providing w_channel, w_radius and w_seed')
    parser.add_argument('--batch_size', default=1, type=int,
                        help='leading dimension of the watermark shape (batch_size, 4, 64, 64)')

    args = parser.parse_args()

    # torch.manual_seed seeds the RNGs of all devices, CUDA included, so no
    # separate torch.cuda.manual_seed / manual_seed_all call is needed.
    torch.manual_seed(args.seed)

    main(args)

